{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a616c3c1",
   "metadata": {},
   "source": [
    "# Run Fine-tune CoT on OpenAI using our `oai` module\n",
    "\n",
    "This notebook contains code to (1) generate reasoning samples from teacher models (e.g., GPT-3 175B `text-davinci-002`), (2) fine-tune student models (e.g., GPT-3 0.3B `ada`), and (3) generate and evaluate samples from fine-tuned student models.\n",
    "\n",
    "- To run from scratch, first download and save the original benchmark data (see README).\n",
    "- To use existing teacher-generated samples, first download and save both the original benchmark data and the teacher completion data (see README). Then replace the completion_key `zs_cot_test` with `zs_cot` in the code below."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a10a6ccc",
   "metadata": {},
   "source": [
    "### TODO: Set OpenAI Key\n",
    "\n",
    "Create an account on OpenAI and retrieve your API key. Experiments will incur fees on your OpenAI account."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2587ed99",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "openai.api_key = \"\""
   ]
  },
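  {
   "cell_type": "markdown",
   "id": "e7b2c9d0",
   "metadata": {},
   "source": [
    "To avoid saving a key in the notebook itself, the key can instead be read from an environment variable. A minimal sketch, assuming the key is exported as `OPENAI_API_KEY`; the helper name `load_api_key` is hypothetical and not part of this repository:\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def load_api_key(env=None, name='OPENAI_API_KEY'):\n",
    "    # Hypothetical helper: look up the key and fail early if it is missing.\n",
    "    env = os.environ if env is None else env\n",
    "    key = env.get(name, '').strip()\n",
    "    if not key:\n",
    "        raise RuntimeError('Set {} before running this notebook.'.format(name))\n",
    "    return key\n",
    "```\n",
    "\n",
    "The result could then be assigned with `openai.api_key = load_api_key()`."
   ]
  },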
  {
   "cell_type": "markdown",
   "id": "86fa4cf8",
   "metadata": {},
   "source": [
    "### Imports and Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4a11f958",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.completion_dataset import CompletionMetadata, CompletionDataset\n",
    "from oai.inference import infer_completion_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "be7d870e",
   "metadata": {},
   "outputs": [],
   "source": [
    "teacher_base_model = \"text-davinci-002\"  # GPT-3 (175B)\n",
    "base_model = \"ada\"                       # GPT-3 (0.3B)\n",
    "# base_model = \"babbage\"                   # GPT-3 (1.3B)\n",
    "# base_model = \"curie\"                     # GPT-3 (6.7B)\n",
    "dataset_key = \"date_understanding\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5043037d",
   "metadata": {},
   "source": [
    "## Infer teacher completions using OpenAI (generate CompletionDataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3fa53308",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note, completion_key identifies the method used to generate completions\n",
    "# Note, prediction_template selects the prediction template from those pre-defined in\n",
    "#       `oai.data.format.Formatter`.\n",
    "completion_metadata = CompletionMetadata(base_model=teacher_base_model, completion_key=\"zs_cot_test\",\n",
    "                                         dataset_key=dataset_key, prediction_template=\"zs_cot\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "16d1d824",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 369 samples from:\n",
      "/Users/itsnamgyu/code/temp/reasoning-teacher/saved/completion_data/B_text-davinci-002__C_zs_cot_test/D_date_understanding.json\n",
      "All 369 samples have been completed.\n"
     ]
    }
   ],
   "source": [
    "# Run Zero-shot-CoT step 1 (rationale generation)\n",
    "# Note, sample_indices=None means we want to infer on all samples\n",
    "completion_dataset = infer_completion_data(completion_metadata, zs_cot_step=1,\n",
    "                                           sample_indices=None, augs=1, temperature=0,\n",
    "                                           max_tokens=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "042067fa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 369 samples from:\n",
      "/Users/itsnamgyu/code/temp/reasoning-teacher/saved/completion_data/B_text-davinci-002__C_zs_cot_test/D_date_understanding.json\n",
      "All 369 samples have been completed.\n"
     ]
    }
   ],
   "source": [
    "# Run Zero-shot-CoT step 2 (answer)\n",
    "completion_dataset = infer_completion_data(completion_metadata, zs_cot_step=2,\n",
    "                                           sample_indices=None, augs=1, temperature=0,\n",
    "                                           max_tokens=128)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b19e754",
   "metadata": {},
   "source": [
    "## Load CompletionDataset and evaluate test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "995c950d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.completion_dataset import CompletionIdentifier\n",
    "from data.split import load_train_test_split\n",
    "from evaluation.evaluator import Evaluator\n",
    "from evaluation.summary import summarize_evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5cf87485",
   "metadata": {},
   "outputs": [],
   "source": [
    "completion_identifier = CompletionIdentifier(teacher_base_model, \"zs_cot_test\", dataset_key)\n",
    "completion_dataset = CompletionDataset.load(completion_identifier)\n",
    "# Note, completion_metadata can be passed in place of completion_identifier, as shown below\n",
    "# completion_dataset = CompletionDataset.load(completion_metadata)\n",
    "train, test = load_train_test_split(dataset_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c80ce146",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluator = Evaluator.for_completion_dataset(completion_dataset)\n",
    "evaluation = evaluator.evaluate_completion_dataset(completion_dataset, test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c36c1760",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sample_index</th>\n",
       "      <th>completion_index</th>\n",
       "      <th>correct</th>\n",
       "      <th>contains_answer</th>\n",
       "      <th>correct_format</th>\n",
       "      <th>complete</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>23</td>\n",
       "      <td>0</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>25</td>\n",
       "      <td>0</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>0</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   sample_index  completion_index  correct  contains_answer  correct_format   \n",
       "0             0                 0     True             True            True  \\\n",
       "1             9                 0     True             True            True   \n",
       "2            23                 0     True             True            True   \n",
       "3            25                 0     True             True            True   \n",
       "4            28                 0     True             True            True   \n",
       "\n",
       "   complete  \n",
       "0      True  \n",
       "1      True  \n",
       "2      True  \n",
       "3      True  \n",
       "4      True  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluation.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f0d58e46",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.7477477477477478,\n",
       " 'contains_answer': 0.7477477477477478,\n",
       " 'correct_format': 1.0,\n",
       " 'complete': 1.0}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summarize_evaluation(evaluation)"
   ]
  },
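  {
   "cell_type": "markdown",
   "id": "b3d4f6a1",
   "metadata": {},
   "source": [
    "The summary above appears to report, for each boolean column of the evaluation table, the fraction of evaluated samples for which that column is true. A minimal sketch of that computation in plain Python, assuming `accuracy` is the mean of the `correct` column; `summarize_rows` is a hypothetical stand-in, not the implementation of `summarize_evaluation`:\n",
    "\n",
    "```python\n",
    "def summarize_rows(rows):\n",
    "    # rows: list of dicts with boolean fields, one dict per evaluated completion.\n",
    "    columns = {\n",
    "        'accuracy': 'correct',\n",
    "        'contains_answer': 'contains_answer',\n",
    "        'correct_format': 'correct_format',\n",
    "        'complete': 'complete',\n",
    "    }\n",
    "    n = len(rows)\n",
    "    if n == 0:\n",
    "        raise ValueError('no rows to summarize')\n",
    "    # Each metric is the share of rows whose source column is true.\n",
    "    return {name: sum(bool(r[col]) for r in rows) / n for name, col in columns.items()}\n",
    "```\n",
    "\n",
    "Under this reading, an accuracy of 0.7477 over 111 test samples corresponds to 83 correct completions."
   ]
  },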
  {
   "cell_type": "markdown",
   "id": "24187a50",
   "metadata": {},
   "source": [
    "## Create fine-tune `File` and `Finetune` using training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bd27aebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from oai.finetune import init_finetune, generate_finetune_data_from_completion_dataset\n",
    "from oai.utils.api_wrapper import fetch_model_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e4edefd6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Replace \"zs_cot_test\" with \"zs_cot\" to use our teacher-generated completions (see README for how to download).\n",
    "completion_identifier = CompletionIdentifier(teacher_base_model, \"zs_cot_test\", dataset_key)\n",
    "completion_dataset = CompletionDataset.load(completion_identifier)\n",
    "train, test = load_train_test_split(dataset_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ee88849d",
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_key = \"zs_cot_test_{}\".format(dataset_key)\n",
    "train_key = \"ft_cot_test\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "55825a80",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving 171 fine-tuning samples to /Users/itsnamgyu/code/temp/reasoning-teacher/saved/finetune_data/P_openai/F_zs_cot_test_date_understanding.jsonl\n"
     ]
    }
   ],
   "source": [
    "# Note, finetune_key is a unique identifier for the finetuning data and should contain the source dataset\n",
    "generate_finetune_data_from_completion_dataset(completion_dataset=completion_dataset,\n",
    "                                               prediction_template=\"ft_cot_token\",\n",
    "                                               finetune_key=finetune_key,\n",
    "                                               sample_indices=train,\n",
    "                                               only_correct=True,  # default\n",
    "                                              )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c98c87e3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"prompt\": \"Yesterday was April 30, 2021. What is the date one year ago from today in MM/DD/YYYY?\\nWhich choice is true? Answer choices: (A) 05/01/1971, (B) 04/01/2020, (C) 05/15/2020, (D) 05/01/2020, (E) 05/08/2020. ###\",\n",
      "    \"completion\": \" One year ago from today would be 2020. Today is 2021. 2020 is two years ago. Two years ago from today would be 05/01/2019. --> D END\"\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Inspect finetune data\n",
    "import json\n",
    "from paths import get_finetune_data_path\n",
    "with open(get_finetune_data_path(\"openai\", finetune_key)) as f:\n",
    "    print(json.dumps(json.loads(f.readline()), indent=4))"
   ]
  },
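  {
   "cell_type": "markdown",
   "id": "c5e8a2f7",
   "metadata": {},
   "source": [
    "The printed sample suggests the layout of each fine-tuning record: the prompt ends with ` ###`, and the completion holds the rationale, a `-->` separator, the answer, and a closing ` END`. A minimal sketch of building and parsing such a record, inferred only from the output above; `format_sample` and `parse_completion` are hypothetical names, and the actual templates are those pre-defined in `oai.data.format.Formatter`:\n",
    "\n",
    "```python\n",
    "PROMPT_END = ' ###'\n",
    "ANSWER_SEP = ' --> '\n",
    "COMPLETION_END = ' END'\n",
    "\n",
    "\n",
    "def format_sample(question, rationale, answer):\n",
    "    # Build one prompt/completion pair in the layout of the printed sample.\n",
    "    return {\n",
    "        'prompt': question + PROMPT_END,\n",
    "        'completion': ' ' + rationale + ANSWER_SEP + answer + COMPLETION_END,\n",
    "    }\n",
    "\n",
    "\n",
    "def parse_completion(completion):\n",
    "    # Return (rationale, answer), or None if either marker is missing.\n",
    "    if not completion.endswith(COMPLETION_END):\n",
    "        return None\n",
    "    body = completion[:-len(COMPLETION_END)]\n",
    "    if ANSWER_SEP not in body:\n",
    "        return None\n",
    "    rationale, answer = body.rsplit(ANSWER_SEP, 1)\n",
    "    return rationale.strip(), answer.strip()\n",
    "```\n",
    "\n",
    "Splitting on the last separator keeps any earlier `-->` inside the rationale intact."
   ]
  },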
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "03bf8d56",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: OpenAI File `zs_cot_test_date_understanding` already exists (likely already uploaded). Skipping.\n",
      "Warning: OpenAI Finetune for `B_ada__D_date_understanding__T_ft_cot_test` already exists. Skipping.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'B_ada__D_date_understanding__T_ft_cot_test'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Note, train_key identifies the method used to train the model, i.e., the method used to fine-tune the base model.\n",
    "init_finetune(finetune_key, base_model, dataset_key, train_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec628615",
   "metadata": {},
   "source": [
    "### Fetch fine-tuned `Model` id\n",
    "\n",
    "Call this function repeatedly until your `Finetune` has finished. Fine-tuning typically takes between 5 minutes and 1 hour."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d83504c2",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No model ids to fetch\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fetch_model_ids()"
   ]
  },
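  {
   "cell_type": "markdown",
   "id": "d9f1b4c3",
   "metadata": {},
   "source": [
    "Rather than re-running the cell by hand, the check can be repeated in a loop. A minimal sketch of a generic polling helper; `poll_until` is hypothetical, and passing `fetch_model_ids` as `check` assumes that it returns `True` once no model ids remain to be fetched, as the output above suggests:\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def poll_until(check, interval=60.0, timeout=3600.0, sleep=time.sleep):\n",
    "    # Call check() until it returns a truthy value or the timeout has passed.\n",
    "    waited = 0.0\n",
    "    while True:\n",
    "        if check():\n",
    "            return True\n",
    "        if waited >= timeout:\n",
    "            return False\n",
    "        sleep(interval)\n",
    "        waited += interval\n",
    "```\n",
    "\n",
    "For example, `poll_until(fetch_model_ids, interval=60)` would check once a minute for up to an hour."
   ]
  },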
  {
   "cell_type": "markdown",
   "id": "ea2ef518",
   "metadata": {},
   "source": [
    "### Access OpenAI metadata\n",
    "\n",
    "We use metadata files to map our identifiers (keys) to the identifiers (ids) used by OpenAI objects.\n",
    "These can be accessed manually as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "01d9f0b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from oai.utils.metadata import get_file_id, get_finetune_id, get_model_id, get_model_key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "48975b97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note that `base_model`, `dataset_key`, `train_key` are joined together to form a `model_key` which\n",
    "# identifies fine-tuned models. There is a one-to-one-to-one mapping between a model_key, Finetune object,\n",
    "# and Model object.\n",
    "model_key = get_model_key(base_model, dataset_key, train_key)"
   ]
  },
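  {
   "cell_type": "markdown",
   "id": "f2a6d8e5",
   "metadata": {},
   "source": [
    "Judging from the key returned by `init_finetune` above (`B_ada__D_date_understanding__T_ft_cot_test`), the three parts appear to be joined with single-letter prefixes and double underscores. A minimal sketch of that convention; `build_model_key` is a hypothetical stand-in inferred from the output and may differ from `get_model_key`:\n",
    "\n",
    "```python\n",
    "def build_model_key(base_model, dataset_key, train_key):\n",
    "    # Prefix each part (B = base model, D = dataset, T = train key) and join with '__'.\n",
    "    return '__'.join(['B_' + base_model, 'D_' + dataset_key, 'T_' + train_key])\n",
    "```\n",
    "\n",
    "Because the key is built only from these three values, reusing the same combination refers to the same fine-tuned model."
   ]
  },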
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e1a46489",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'file-3lwlV7lJRebTr0JTniuZ7lCX'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Note that our `finetune_key` identifies the fine-tuning \"data\", so it is mapped to a File object\n",
    "# rather than a Finetune object.\n",
    "get_file_id(finetune_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "95162d8c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'ft-ord6Qs8vmXQI8VjVWrfNTrTs'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_finetune_id(model_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "71570d68",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'ada:ft-namgyu-ho-2023-06-11-04-37-04'"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_model_id(model_key)  # fetched by `fetch_model_ids()`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17c0460e",
   "metadata": {},
   "source": [
    "## Infer student completions\n",
    "\n",
    "We run inference only on the test set samples, which are used for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "df8e3ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note, completion_key and train_key are both \"ft_cot_test\". Recall that completion_key refers to\n",
    "# the method used to generate completions by the student model, and train_key refers to the method\n",
    "# used to train the student model.\n",
    "completion_metadata = CompletionMetadata(base_model=base_model, completion_key=\"ft_cot_test\",\n",
    "                                         dataset_key=dataset_key, finetune_key=finetune_key,\n",
    "                                         prediction_template=\"ft_cot_token\",\n",
    "                                         train_key=train_key, epoch=None)\n",
    "train, test = load_train_test_split(dataset_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "aa1a4a0e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing new CompletionDataset at:\n",
      "/Users/itsnamgyu/code/temp/reasoning-teacher/saved/completion_data/B_ada__C_ft_cot_test/D_date_understanding__T_ft_cot_test.json\n",
      "Inferring completions for 111 remaining samples (total=111)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Inferring completions via OpenAI: 100%|███████████████████████████| 111/111 [00:12<00:00,  8.72it/s]\n"
     ]
    }
   ],
   "source": [
    "# Note, `infer_completion_data` will find our new student model (that we fetched above) by using\n",
    "#       `base_model`, `dataset_key`, and `train_key`, which are specified in `completion_metadata`.\n",
    "completion_dataset = infer_completion_data(completion_metadata, zs_cot_step=None,\n",
    "                                           sample_indices=test, augs=1, temperature=0,\n",
    "                                           max_tokens=1024)  # note, we use 1024 tokens for student inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75c623f5",
   "metadata": {},
   "source": [
    "## Evaluate student completions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "afc851ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "completion_identifier = CompletionIdentifier(base_model, completion_key=\"ft_cot_test\", dataset_key=dataset_key,\n",
    "                                             train_key=\"ft_cot_test\")\n",
    "completion_dataset = CompletionDataset.load(completion_identifier)\n",
    "train, test = load_train_test_split(dataset_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "21eac841",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "evaluator = Evaluator(dataset_key, \"ft_cot_token\")\n",
    "evaluation = evaluator.evaluate_completion_dataset(completion_dataset, test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "a4181fbb",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.12612612612612611,\n",
       " 'contains_answer': 0.12612612612612611,\n",
       " 'correct_format': 0.9819819819819819,\n",
       " 'complete': 0.9819819819819819}"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summarize_evaluation(evaluation)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
